#%% 
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import argparse

# Command-line hyperparameters for the SimSHAP training run.
# Each entry is (flag, keyword arguments forwarded to ``add_argument``).
parser = argparse.ArgumentParser()
for _flag, _options in (
    ('--lr', {'type': float, 'default': 7e-4}),
    ('--batch_size', {'type': int, 'default': 1024}),
    ('--num_samples', {'type': int, 'default': 32}),
    ('--epochs', {'type': int, 'default': 1000}),
    ('--pair_sampling', {'action': 'store_true'}),
):
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

#%% 
# Read the raw table; the 'Flag' column is the target, every other column a feature.
file = 'Taiwan_data_ENG_95.csv'
data = pd.read_csv(file, encoding='utf-8')

#%% 
X = np.array(data.drop(columns=['Flag']))
Y = np.array(data['Flag'])

# Hold out 20% as the test set, then 20% of what remains as the validation set.
# Fixed seeds keep both splits reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=70)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train, test_size=0.2, random_state=42)

# Data scaling: the scaler is fitted on the training split only and then
# applied unchanged to all three splits.
num_features = X_train.shape[1]
ss = StandardScaler().fit(X_train)
X_train, X_val, X_test = (
    ss.transform(split) for split in (X_train, X_val, X_test))


#%% Load Model
import pickle
import os.path
import lightgbm as lgb
from lightgbm import log_evaluation, early_stopping
# Restore the pickled model from disk.
# NOTE: unpickling executes arbitrary code, so only load a file you trust.
with open('bank model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)


#%% Load Surrogate
import torch
import torch.nn as nn
from fastshap.utils import MaskLayer1d
from fastshap import Surrogate, KLDivLoss

# Select device: use the GPU when one is visible, otherwise fall back to the
# CPU so the script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location remaps tensors that were saved on a device which is not present
# here (e.g. a GPU checkpoint opened on a CPU-only host); without it torch.load
# raises while deserializing.
# The checkpoint appears to hold a whole pickled module (the result is used
# directly as a model), so only load files you trust.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects full-module checkpoints; pass weights_only=False there if loading
# fails -- confirm against the installed version.
surr = torch.load('bank surrogate.pt', map_location=device).to(device)
surrogate = Surrogate(surr, num_features)

#%% Load FastSHAP
from simshap.fastshap_plus import FastSHAP
# Load the previously trained FastSHAP explainer network.
# map_location places the checkpoint's tensors on ``device`` at load time, so a
# file saved on a different (or unavailable) device still deserializes.
# The file is unpickled by torch.load, so only load checkpoints you trust.
explainer_fastshap = torch.load(
    'bank fastshap.pt', map_location=device).to(device)

# Wrap the network with the surrogate; the identity link leaves the surrogate
# outputs untransformed.
fastshap = FastSHAP(explainer_fastshap, surrogate, normalization='additive',
                    link=nn.Identity())
#%% Train Default
from simshap.simshap_sampling import SimSHAPSampling
import sys
# Make the parent directory importable before importing `models`.
# '..' is resolved against the current working directory (not this file's
# location), so the import below depends on where the script is launched from.
sys.path.append('..')
from models import SimSHAPTabular
# Create explainer model
# A fresh explainer is always built here; no existing checkpoint is looked up.
# NOTE(review): out_dim=2 presumably matches the two classes of the binary
# 'Flag' target -- confirm against SimSHAPTabular.
explainer = SimSHAPTabular(in_dim=num_features, hidden_dim=512, out_dim=2).to(device)

# Set up SimSHAP object
# The surrogate loaded above is passed as the imputer.
simshap = SimSHAPSampling(explainer=explainer, imputer=surrogate, device=device)
# Train
# Options taken from the command line come first; the rest are fixed for this run.
train_options = dict(
    batch_size=args.batch_size,
    num_samples=args.num_samples,
    max_epochs=args.epochs,
    paired_sampling=args.pair_sampling,
    lr=args.lr,
    bar=False,
    validation_samples=128,
    verbose=True,
    lookback=10,
    lr_factor=0.5,
)
# Only the first 100 validation rows are handed to the trainer.
simshap.train(X_train, X_val[:100], **train_options)

# Save explainer
# The explainer is moved to the CPU before it is serialized (presumably so the
# checkpoint does not depend on a GPU being present) and moved back afterwards.
explainer.cpu()
torch.save(explainer, 'bank simshap ablation.pt')
explainer.to(device)
